from __future__ import absolute_import

import logging
import ast
from datetime import datetime
from operator import itemgetter

import numpy as np
from rest_framework import status
from rest_framework.response import Response
from django.core.paginator import Paginator, PageNotAnInteger, EmptyPage

logger = logging.getLogger(__name__)


def es_search(doc, order, search_method, search_max=100, time_field=None, **kwargs):
    """Run a filtered Elasticsearch query and return its hits.

    Every keyword argument whose name appears in ``search_method`` becomes one
    clause of a ``bool``/``filter`` query; other keyword arguments are ignored.

    :param doc: object exposing ``search()`` (presumably an elasticsearch-dsl
        Document -- confirm against callers)
    :param order: sort specification passed to ``sort()``
    :param search_method: mapping of field name -> query clause type
    :param search_max: maximum number of hits to request, default=100
    :param time_field: names of fields whose value is a ``[start, end]`` pair;
        these are sent as ``{"gte": start, "lte": end}`` bounds
    :param kwargs: field name -> value to filter on
    :return: the ``hits`` attribute of the executed search response
    """
    range_fields = [] if time_field is None else time_field

    clauses = []
    for name, raw in kwargs.items():
        if name in search_method:
            method = search_method[name]
            if name in range_fields:
                # Range bounds take priority over the wildcard rewrite.
                value = {"gte": raw[0], "lte": raw[1]}
            elif method == "wildcard":
                # Substring match: wrap the value in wildcards.
                value = f'*{raw}*'
            else:
                value = raw
            clauses.append({method: {name: value}})

    search = doc.search().extra(size=search_max)
    search = search.query({"bool": {"filter": clauses}})
    response = search.sort(order).execute()
    return response.hits


def db_search(db_table, order, search_method, search_max=100, **kwargs):
    """Run a filtered, ordered query against a database table.

    Every keyword argument whose name appears in ``search_method`` is turned
    into a ``<field>__<lookup>`` filter; other keyword arguments are ignored.

    :param db_table: object exposing ``objects.filter(...)`` (presumably a
        Django model class -- confirm against callers)
    :param order: ordering passed to ``order_by()``
    :param search_method: mapping of field name -> lookup suffix
    :param search_max: maximum number of rows to return, default=100
    :param kwargs: field name -> value to filter on
    :return: the ordered result sliced to at most ``search_max`` entries
    """
    manager = db_table.objects
    lookups = {
        f'{name}__{search_method[name]}': value
        for name, value in kwargs.items()
        if name in search_method
    }
    ordered = manager.filter(**lookups).order_by(order)
    return ordered[:search_max]


def get_union_field_values(results, search_engine, union_fields):
    """Collect the distinct values of each requested field across the results.

    Values are kept as the native Python objects found on the results and are
    listed in first-seen order. The placeholder values ``''`` and ``'--'`` are
    left out. An empty ``results`` yields an empty list for every field, and a
    single-element ``union_fields`` is handled like any other length.

    :param results: iterable of search results
    :param search_engine: ``"db"`` reads each result's ``__dict__``; any other
        value reads ``__dict__["_d_"]`` (the layout used for es results)
    :param union_fields: list of field names to collect values for
    :return: dict mapping each field name to the list of its distinct values
    :raises KeyError: if a result does not contain one of the fields
    """
    placeholders = ("", "--")
    # Plain dicts double as insertion-ordered sets, so the output order is
    # deterministic instead of depending on set iteration order.
    distinct = {field: {} for field in union_fields}

    for result in results:
        record = result.__dict__ if search_engine == "db" else result.__dict__["_d_"]
        for field in union_fields:
            value = record[field]
            if value not in placeholders:
                distinct[field][value] = None

    return {field: list(values) for field, values in distinct.items()}


def check_body(body, time_field=None, fields=None):
    """Validate and convert the raw filter fields of a request body.

    Entries with an empty key or an empty value (``None``, ``''``, an empty
    container, or a list holding only ``''``) are dropped. ``body`` itself is
    never modified: list values are filtered into a new list.

    :param body: mapping of field name -> raw filter value
    :param time_field: optional dict ``{field_name: [ft_format, tt_format]}``
        describing time-range fields, where
        field_name: name of the field
        ft_format: ``strptime`` format of the incoming timestamps
        tt_format: ``strftime`` format of the returned timestamps
        The raw value must be a Python-literal string of a ``[start, end]``
        pair.
    :param fields: optional list of field names whose raw value is a
        Python-literal string to be parsed with ``ast.literal_eval``
    :return: dict of the converted fields
    :raises ValueError: if a literal or a timestamp cannot be parsed
    :raises SyntaxError: if a literal string is not valid Python syntax
    """
    fields = [] if fields is None else fields
    # The default used to stay None and crash on the membership test below.
    time_field = {} if time_field is None else time_field

    res = {}
    for key in body:
        if not key:
            continue

        value = body[key]
        if isinstance(value, list):
            # Filter into a copy so the caller's list is left untouched.
            value = [item for item in value if item != ""]

        # Only sized values can be "empty"; numbers and the like pass through.
        if value is None or (hasattr(value, "__len__") and len(value) == 0):
            continue

        if key in fields:
            # literal_eval only accepts literals, so it cannot execute code.
            res[key] = ast.literal_eval(value)
        elif key in time_field:
            range_time = ast.literal_eval(value)
            from_format = time_field[key][0]
            to_format = time_field[key][1]
            res[key] = [
                datetime.strptime(moment, from_format).strftime(to_format)
                for moment in (range_time[0], range_time[1])
            ]
        else:
            res[key] = value
    return res

def set_trace_view(self):
    """Populate the search settings of ``self`` from ``self.template``.

    Sets ``es_search_method`` and ``db_search_method`` (field name -> the
    ``"es"`` / ``"db"`` entry of the template's ``search_method``),
    ``union_fields``, ``search_engine`` (the template value, or ``"db"`` when
    the template has no such key) and ``search_max`` (fixed at 999).

    :param self: object with a dict-like ``template`` attribute
    :return: None
    """
    template = self.template
    per_field = template.get("search_method")

    self.es_search_method = {}
    self.db_search_method = {}
    for name in per_field:
        engines = per_field[name]
        self.es_search_method[name] = engines["es"]
        self.db_search_method[name] = engines["db"]

    self.union_fields = template.get("union_fields")
    self.search_engine = template.get("search_engine", "db")
    self.search_max = 999

def get_page_list(search_res_list, page, page_size):
    """Return one page of the results together with its paginator.

    A page number that is not an integer falls back to the first page; a page
    number for which the paginator raises ``EmptyPage`` falls back to the last
    page.

    :param search_res_list: the full list of results to paginate
    :param page: requested page number
    :param page_size: number of results per page
    :return: tuple ``(page_object, paginator)``
    """
    pager = Paginator(search_res_list, page_size)
    try:
        return pager.page(page), pager
    except PageNotAnInteger:
        fallback = 1
    except EmptyPage:
        fallback = pager.num_pages
    return pager.page(fallback), pager

def response_total(response, data_list, count, union_field_values):
    """Store the search results on ``response`` and wrap it in an HTTP 200.

    :param response: object with an item-assignable ``data`` attribute and a
        ``dict`` attribute used as the response payload
    :param data_list: value stored under ``"data_list"``
    :param count: value stored under ``"count"``
    :param union_field_values: value stored under ``"union_field_values"``
    :return: a ``Response`` built from ``response.dict`` with status 200
    """
    payload = response.data
    entries = (
        ("data_list", data_list),
        ("count", count),
        ("union_field_values", union_field_values),
    )
    for key, value in entries:
        payload[key] = value
    return Response(response.dict, status=status.HTTP_200_OK)
